import os

import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from tslearn.barycenters import euclidean_barycenter

from constants import WEEK_MEASUREMENTS,FEAT_IDX,FEAT_LABEL

# Fixed y-axis limits per feature name. Features not listed here keep
# matplotlib's autoscaled y-limits.
_FEATURE_YLIM = {
    'Home': (-0.1, 1.1),
    'SOC': (-5, 105),
    'delta_soc': (-12, 35),
    'dod': (-5, 105),
    'charging_power_level': (-0.1, 3.1),
    'charging_energy_kwh': (-5, 85),
    'weekly_cycle': (0, 13),
}

# Captions drawn along the bottom of each page, one per day of the week.
_DAY_LABELS = ('Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun')


def _decorate_week_axes(fig, ax, feature, label):
    """Apply day separators, weekday captions, title, limits and labels to one cluster page."""
    day_len = int(WEEK_MEASUREMENTS / 7)
    # Dotted separators between consecutive days; the upper bound stops short of
    # WEEK_MEASUREMENTS so no separator is drawn at the end of the week.
    for x in range(day_len, WEEK_MEASUREMENTS - 2, day_len):
        ax.axvline(x, color='k', linestyle='dotted')
    # Captions are positioned in figure coordinates (0..1), evenly spaced
    # under the axes.
    for pos, day in enumerate(_DAY_LABELS):
        fig.text(pos / 9 + 0.16, 0.05, day, transform=fig.transFigure)
    ax.set_title(f'Cluster {label}', fontsize=16)
    ax.grid(False)
    ax.set_xlim(0, WEEK_MEASUREMENTS)
    ax.set_xticks([])
    ylim = _FEATURE_YLIM.get(feature)
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.set_ylabel(FEAT_LABEL[feature])


def plot_and_save_load_profiles(train_characterize, train_raw_ns, output, algorithm, base_directory, scaler, cluster_results):
    """Write one multi-page PDF per feature showing every cluster's member profiles.

    For each feature in ``FEAT_IDX`` a file ``raw_<feature>_plot.pdf`` is written
    into ``base_directory``. The file has one page per cluster label. Each member
    profile is drawn in gray. A representative profile is overlaid in red, except
    for label -1 (presumably the noise label of a density-based clusterer; to be
    confirmed against the caller). The representative is:

    * the cluster medoid, taken from ``train_raw_ns`` and passed through
      ``scaler.inverse_transform``, when ``algorithm == "kmedoids"``;
    * otherwise, the ``tslearn`` Euclidean barycenter of the members' profiles.

    Args:
        train_characterize: Indexable by the values in ``output['index']``. Each
            item is transposed and indexed by the feature index from ``FEAT_IDX``.
            Presumably shaped (time, n_features); TODO confirm.
        train_raw_ns: Indexable like ``train_characterize``. Only used for
            medoids, via ``scaler.inverse_transform``.
        output: DataFrame with a ``'cluster'`` label column and an ``'index'``
            column that maps each row to a sample in ``train_characterize``.
        algorithm: Clustering algorithm name. ``"kmedoids"`` selects medoid mode.
        base_directory: Directory (str or path-like) that receives the PDFs.
        scaler: Object with an ``inverse_transform`` method. Only used in medoid mode.
        cluster_results: Mapping. Its ``"medoids"`` entry is read only in medoid
            mode and is indexed by cluster label to give a row label of ``output``.

    Returns:
        True once all PDFs have been written.
    """
    use_medoids = algorithm == "kmedoids"
    # Only read medoids when they are needed, so results from other algorithms
    # do not have to provide a "medoids" entry.
    medoids = cluster_results["medoids"] if use_medoids else None
    grouped_output = output.groupby('cluster')
    labels = sorted(set(output['cluster']))

    for feature, f in FEAT_IDX.items():
        pdf_path = os.path.join(base_directory, 'raw_' + feature + '_plot.pdf')
        with PdfPages(pdf_path) as pdf:
            for label in labels:
                members = grouped_output.get_group(label)
                fig, ax = plt.subplots(1, 1, figsize=(8, 3))
                try:
                    profiles = []
                    for sample_idx in members['index']:
                        data = train_characterize[sample_idx].T[f]
                        ax.plot(data, c='gray', alpha=0.3)
                        if not use_medoids:
                            profiles.append(data)

                    # Label -1 gets no representative overlay.
                    if label != -1:
                        if use_medoids:
                            medoid_sample = output.loc[medoids[label], 'index']
                            representative = scaler.inverse_transform(train_raw_ns[medoid_sample]).T[f]
                        else:
                            representative = euclidean_barycenter(profiles)
                        ax.plot(representative, c='r')

                    _decorate_week_axes(fig, ax, feature, label)
                    pdf.savefig(fig)
                finally:
                    # Always release the figure, even if plotting fails,
                    # so open figures do not accumulate across clusters.
                    plt.close(fig)
    return True




